
fix(ninetoothed): bound RmsNorm autotune shapes #624

Draft
voltjia wants to merge 1 commit into master from fix/ninetoothed-rms-norm-autotune

@voltjia voltjia commented May 28, 2026

Summary

  • Bound the synthetic auto-tuning shapes used by the NineToothed RmsNorm build configuration in src/ninetoothed/ops/rms_norm/build.py.
  • Keep runtime rank-2/3/4 RmsNorm coverage unchanged while avoiding extremely large build-time warmup tensors for rank-4 configurations.
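
A minimal sketch of the bounding rule, assuming the values this PR describes (normalized dimension kept at 256, synthetic batch dimensions bounded to 1). The helper name `bounded_warmup_shape` and its parameters are hypothetical and are not taken from src/ninetoothed/ops/rms_norm/build.py; the sketch only illustrates which warmup shapes result for each supported rank.

```python
def bounded_warmup_shape(rank, normalized_size=256, batch_size=1):
    """Hypothetical helper: leading (batch) dimensions are bounded to
    batch_size, and the last (normalized) dimension keeps normalized_size."""
    if rank < 1:
        raise ValueError("rank must be at least 1")

    return (batch_size,) * (rank - 1) + (normalized_size,)


# Warmup shapes for the ranks covered at runtime (2, 3, and 4).
shapes = {rank: bounded_warmup_shape(rank) for rank in (2, 3, 4)}

for rank, shape in shapes.items():
    print(rank, shape)
```

Only the shapes fed to build-time tuning change in this sketch; the set of ranks stays the same.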

Motivation

The RmsNorm NineToothed build currently lets every symbolic dimension use the generic build warmup size. For rank-4 RmsNorm, this creates synthetic tensors with shape (256, 256, 256, 256) during auto-tuning. Each such float32 tensor holds 256^4 elements and takes 16 GiB, and the build can fail with a CUDA illegal memory access during triton.testing.do_bench synchronization.
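
The allocation pressure can be checked with simple arithmetic. The sketch below assumes dense float32 storage (4 bytes per element, matching the float32 tensors mentioned in the reviewer notes); `tensor_bytes` is an illustrative helper, not part of the build scripts.

```python
GIB = 1024 ** 3


def tensor_bytes(shape, itemsize=4):
    """Bytes needed for a dense tensor; itemsize=4 corresponds to float32."""
    count = 1

    for dim in shape:
        count *= dim

    return count * itemsize


unbounded = (256, 256, 256, 256)
per_tensor_gib = tensor_bytes(unbounded) / GIB

print(per_tensor_gib)      # 16.0 GiB for one rank-4 warmup tensor
print(3 * per_tensor_gib)  # 48.0 GiB if three such tensors are live at once
```

For comparison, a (1, 1, 1, 256) float32 tensor needs 1 KiB under the same assumptions.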

Closes #N/A

Type of Change

  • feat — new feature / new operator / new platform
  • fix — bug fix
  • perf — performance improvement (no behavioral change)
  • refactor — code restructuring without behavior change
  • test — adding or fixing tests only
  • build / ci — build system or CI configuration
  • docs — documentation only
  • chore — tooling, formatting, or other non-code changes
  • Breaking change (requires a ! in the Conventional Commits prefix or a BREAKING CHANGE: footer)

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • Build system / CMake / CI
  • Python bindings / user-facing API

Test Results on Supported Platforms

| Platform | Built | pytest Result | Notes / Hardware |
| --- | --- | --- | --- |
| NVIDIA | Yes | N/A | Focused NineToothed RmsNorm generation passes; full NVIDIA build with WITH_NINETOOTHED=ON passed during diagnosis with this fix shape. |
| Iluvatar | N/A | N/A | Not affected; NineToothed RmsNorm is NVIDIA-only. |
| MetaX | N/A | N/A | Not affected; NineToothed RmsNorm is NVIDIA-only. |
| Cambricon | N/A | N/A | Not affected; NineToothed RmsNorm is NVIDIA-only. |
| Moore | N/A | N/A | Not affected; NineToothed RmsNorm is NVIDIA-only. |
| Ascend | N/A | N/A | Not affected; NineToothed RmsNorm is NVIDIA-only. |

Validation output
ruff format --check src/ninetoothed/ops/rms_norm/build.py
1 file already formatted

ruff check src/ninetoothed/ops/rms_norm/build.py
All checks passed!

python3 scripts/generate_ninetoothed_ops.py --ops rms_norm --output-dir /tmp/nt-rms-shapeoptions
# exit 0

# Diagnostic full NVIDIA build with the same fix shape:
END build status=0 seconds=1447

Benchmark / Performance Impact

N/A. This affects build-time auto-tuning inputs only; runtime RmsNorm dispatch coverage is unchanged.

Notes for Reviewers

  • The failure mode was reproduced by instrumenting AutoTuner._get_timing: the old build configuration passed three rank-4 float32 tensors with shape (256, 256, 256, 256) into the benchmark path.
  • The fix keeps the normalized dimension at 256 for build-time tuning but bounds synthetic batch dimensions to 1, matching the reduction structure of RmsNorm and avoiding unrealistic build-time allocation pressure.
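
Why a batch size of 1 still exercises the same reduction can be sketched with a reference RmsNorm in plain Python. This assumes the common formulation (divide by the root mean square over the last dimension, then scale by a weight); the function names and the `eps` default are illustrative and are not taken from the NineToothed kernel.

```python
import math


def rms_norm_row(row, weight, eps=1e-6):
    """Reference RmsNorm over one row, i.e. over the normalized dimension."""
    mean_square = sum(x * x for x in row) / len(row)
    scale = 1.0 / math.sqrt(mean_square + eps)

    return [x * scale * w for x, w in zip(row, weight)]


def rms_norm_batch(rows, weight, eps=1e-6):
    """Apply RmsNorm to each row independently; batch rows never interact."""
    return [rms_norm_row(row, weight, eps) for row in rows]


weight = [1.0, 1.0]
batch = [[3.0, 4.0], [1.0, 1.0]]

# Normalizing a row alone gives the same result as normalizing it in a batch.
alone = rms_norm_row(batch[0], weight)
in_batch = rms_norm_batch(batch, weight)[0]
print(alone == in_batch)
```

Because each row is reduced independently, the per-row work depends only on the normalized dimension, which the build-time shapes keep at 256.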

Checklist

Title, Branch, and Commits

  • PR title follows Conventional Commits (e.g. feat(nvidia): …, fix(cuda/gemm): …).
  • Branch name follows <type>/xxx-yyyy-zzzz where <type> matches the PR title's Conventional Commits type and words are joined with hyphens (see CONTRIBUTING.md §Branches).
  • Each commit message follows Conventional Commits.
  • Small PR is a single squashable commit; or, for a large PR, every commit is meaningful, well-formed, and independently reviewable (see CONTRIBUTING.md §Pull Requests).
  • No stray merge commits from master — the branch is rebased cleanly on top of the current master.
  • No fixup! / squash! / wip commits remain.

Scope and Design

  • Changes are minimal — nothing unrelated to the stated motivation was added (CONTRIBUTING.md §Code/General).
  • No dead code, commented-out blocks, debug prints, printf/std::cout/print(...) left behind, or TODO without an owner and issue link.
  • No unrelated formatting churn that would obscure the diff.
  • N/A. No public API changes are introduced.

General Code Hygiene (applies to all languages)

  • The code is self-explanatory; comments were added only where the why is non-obvious (CONTRIBUTING.md §Code/General).
  • Every modified or added file ends with a single trailing newline (CONTRIBUTING.md §Code/General).
  • No trailing whitespace, tab/space mixing, or stray BOMs.
  • N/A. No new comments or error messages were added.
  • All comments and error messages are in English (CONTRIBUTING.md §Code/General).
  • N/A. No new comments or error messages were added.

C++ Specific (if C++ files changed)

  • N/A. No C++ files changed.
  • N/A. No new operator is added.

Python Specific (if Python files changed)

  • Code is PEP 8 compliant; ruff check passes cleanly on CI (see .github/workflows/ruff.yml).
  • ruff format --check passes cleanly — if not, run ruff format and commit the result.
  • N/A. No comments were added.
  • N/A. No framework-specific messages were added.
  • No blank line between the function signature and the body when there is no docstring or comment (CONTRIBUTING.md §Python).
  • A blank line is present before and after if, for, and similar control-flow statements (CONTRIBUTING.md §Python).
  • A blank line appears before each return, except when it directly follows a control-flow statement (CONTRIBUTING.md §Python).
  • N/A. No docstrings were added.
  • Type hints are added / kept consistent with the surrounding code.

Testing

  • pytest was run locally on every supported platform that this PR can affect, and the results are recorded in the "Test Results" table above (CONTRIBUTING.md §Pull Requests).
  • N/A. Non-NVIDIA platforms are not affected by WITH_NINETOOTHED RmsNorm generation.
  • N/A. No new functionality is added.
  • N/A. No tests were added.
  • N/A. The regression is build-time GPU memory pressure in a generated-code path; the focused generator command reproduces the affected path.

Build, CI, and Tooling

  • The project builds cleanly from a fresh directory with pip install .[dev] on at least one affected platform.
  • N/A. This PR does not change CMake compile command generation.
  • N/A. No new backend or device is added.
  • Only one CUDA-like GPU backend is selectable at a time — the existing mutual-exclusion check in CMakeLists.txt is not broken.
  • ruff is green locally; no C++ files were changed for clang-format.
  • N/A. No new runtime dependency was added.

Documentation

  • N/A. No user-facing build flag or workflow changed.
  • N/A. No new operator, dispatch helper, or public utility is added.
  • N/A. No user-visible breaking change is introduced.

Security and Safety

  • No secrets, access tokens, internal URLs, customer data, or personal hardware identifiers have been committed.
  • N/A. No third-party code was added.
  • No unsafe pointer arithmetic, uninitialized reads, or missing bounds checks were introduced.
